# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# This code is adapted from https://github.com/huggingface/diffusers
# with modifications to run diffusers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
from typing import List, Literal, Optional, Union, cast

import numpy as np
import requests
from PIL import Image

import mindspore as ms

from ..image_processor import VaeImageProcessor
from ..video_processor import VideoProcessor
from .deprecation_utils import deprecate

# Maps the dtype name carried in an endpoint response's `dtype` header to the
# MindSpore dtype used when rebuilding the returned tensor.
DTYPE_MAP = {
    "float16": ms.float16,
    "float32": ms.float32,
    "bfloat16": ms.bfloat16,
    "uint8": ms.uint8,
}


def detect_image_type(data: bytes) -> str:
    """
    Guess an image container format from the leading magic bytes of `data`.

    Args:
        data (`bytes`):
            Raw image bytes.

    Returns:
        `str`: One of `"jpeg"`, `"png"`, `"gif"`, `"bmp"`, or `"unknown"` when no signature matches.
    """
    # Checked in order; each entry pairs a format label with its accepted signatures.
    known_signatures = (
        ("jpeg", (b"\xff\xd8",)),
        ("png", (b"\x89PNG\r\n\x1a\n",)),
        ("gif", (b"GIF87a", b"GIF89a")),
        ("bmp", (b"BM",)),
    )
    for label, magics in known_signatures:
        if data.startswith(magics):
            return label
    return "unknown"


def check_inputs_decode(
    endpoint: str,
    tensor: "ms.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "ms"] = "pil",
    return_type: Literal["mp4", "pil", "ms"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["binary"] = "binary",
    output_tensor_type: Literal["binary"] = "binary",
    height: Optional[int] = None,
    width: Optional[int] = None,
):
    """
    Validate the arguments of `remote_decode` before a request is built.

    Only `tensor`, `processor`, `do_scaling`, `scaling_factor`, `output_type`, `return_type`,
    `partial_postprocess`, `height` and `width` are inspected; the remaining parameters are accepted so the
    signature mirrors `remote_decode`.

    Raises:
        `ValueError`: If `tensor` is 3-dimensional (packed latents) and `height` or `width` is missing, or if
            `output_type="ms"`, `return_type="pil"` and `partial_postprocess=False` are combined without a
            `VaeImageProcessor`/`VideoProcessor`.
    """
    # Packed latents need BOTH dimensions: the request only forwards them when both are set,
    # so a single missing value must be rejected here instead of being silently dropped.
    if tensor.ndim == 3 and (height is None or width is None):
        raise ValueError("`height` and `width` required for packed latents.")
    if (
        output_type == "ms"
        and return_type == "pil"
        and not partial_postprocess
        and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
    ):
        raise ValueError("`processor` is required.")
    # Relying on `do_scaling` alone (no explicit factor) is the deprecated calling convention.
    if do_scaling and scaling_factor is None:
        deprecate(
            "do_scaling",
            "1.0.0",
            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
            standard_warn=False,
        )


def postprocess_decode(
    response: requests.Response,
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    output_type: Literal["mp4", "pil", "ms"] = "pil",
    return_type: Literal["mp4", "pil", "ms"] = "pil",
    partial_postprocess: bool = False,
):
    """
    Convert the endpoint response of a remote decode into the requested return type.

    Args:
        response (`requests.Response`):
            Response of the decode endpoint. For tensor payloads the `shape` (JSON list) and `dtype` headers
            describe the raw bytes in `response.content`.
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Used to turn a tensor into PIL output; with `output_type="pil"` any non-`None` value signals that
            the payload is a tensor rather than encoded image bytes.
        output_type (`"mp4"` or `"pil"` or `"ms"`, default `"pil"`):
            Type of payload the endpoint returned.
        return_type (`"mp4"` or `"pil"` or `"ms"`, default `"pil"`):
            Type this function returns.
        partial_postprocess (`bool`, default `False`):
            Whether the endpoint already postprocessed the tensor into image data.

    Returns:
        `PIL.Image.Image`, a list of them, `bytes` (mp4) or `ms.Tensor`, depending on the arguments.

    Raises:
        `KeyError`: If the `dtype` header is not a key of `DTYPE_MAP`.
        `ValueError`: If the `output_type`/`return_type` combination is not supported.
    """
    if output_type == "ms" or (output_type == "pil" and processor is not None):
        parameters = response.headers
        shape = json.loads(parameters["shape"])
        dtype = parameters["dtype"]
        mindspore_dtype = DTYPE_MAP[dtype]
        # A bytearray keeps the buffer writable, as the original conversion did.
        raw = bytearray(response.content)
        # The payload must be read with the dtype named in the header: `np.frombuffer` defaults to
        # float64, which would misread the bytes and yield the wrong number of elements.
        if dtype == "bfloat16":
            # NumPy has no bfloat16. Each 16-bit pattern is the upper half of a float32, so widening
            # it into the high bits gives the exact value (assumes native byte order — TODO confirm).
            array = (np.frombuffer(raw, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
        else:
            array = np.frombuffer(raw, dtype=np.dtype(dtype))
        output_tensor = ms.tensor(array, dtype=mindspore_dtype).reshape(shape)

    if output_type == "ms":
        if partial_postprocess:
            if return_type == "pil":
                images = [Image.fromarray(image.numpy()) for image in output_tensor]
                # A single image is returned bare rather than as a one-element list.
                return images[0] if len(images) == 1 else images
            if return_type == "ms":
                return output_tensor
        elif processor is None or return_type == "ms":
            return output_tensor
        elif isinstance(processor, VideoProcessor):
            return cast(
                List[Image.Image],
                processor.postprocess_video(output_tensor, output_type="pil")[0],
            )
        else:
            return cast(
                Image.Image,
                processor.postprocess(output_tensor, output_type="pil")[0],
            )
    elif output_type == "pil" and processor is None:
        if return_type == "pil":
            # The endpoint sent encoded image bytes; `convert` drops the format, so restore it.
            output = Image.open(io.BytesIO(response.content)).convert("RGB")
            output.format = detect_image_type(response.content)
            return output
    elif output_type == "pil":
        # `processor` is set: the endpoint sent a tensor that is indexed here as (N, C, H, W).
        if return_type == "pil":
            return [
                Image.fromarray(image)
                for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
            ]
        if return_type == "ms":
            return output_tensor
    elif output_type == "mp4" and return_type == "mp4":
        return response.content

    # Previously these combinations fell through to an unbound local variable.
    raise ValueError(
        f"Unsupported combination: output_type={output_type!r}, return_type={return_type!r}, "
        f"partial_postprocess={partial_postprocess!r}."
    )


def prepare_decode(
    tensor: "ms.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "ms"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    height: Optional[int] = None,
    width: Optional[int] = None,
):
    """
    Build the `requests.post` keyword arguments for a remote decode call.

    Args:
        tensor (`ms.Tensor`):
            Latents to send; serialized through `tensor.numpy().tobytes()`.
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Only checked for `None`: with `output_type="pil"` and no processor the endpoint is asked for
            encoded image bytes instead of a tensor.
        do_scaling (`bool`, default `True`):
            Gate for the scaling parameters; sent as `do_scaling` itself only when no `scaling_factor` is given.
        scaling_factor (`float`, *optional*):
            Forwarded when `do_scaling` is truthy.
        shift_factor (`float`, *optional*):
            Forwarded when `do_scaling` is truthy.
        output_type (`"mp4"` or `"pil"` or `"ms"`, default `"pil"`):
            Endpoint output type.
        image_format (`"png"` or `"jpg"`, default `"jpg"`):
            Image encoding requested with `output_type="pil"`.
        partial_postprocess (`bool`, default `False`):
            Forwarded to the endpoint unchanged.
        height (`int`, *optional*):
            Forwarded only when `width` is also set.
        width (`int`, *optional*):
            Forwarded only when `height` is also set.

    Returns:
        `dict`: `{"data": bytes, "params": dict, "headers": dict}`.
    """
    parameters = {
        "image_format": image_format,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        # Lowercased so the name matches the lowercase dtype names used in this module's `DTYPE_MAP`.
        # NOTE(review): MindSpore dtypes appear to stringify capitalized (e.g. `Float16`) — confirm.
        "dtype": str(tensor.dtype).split(".")[-1].lower(),
    }
    if do_scaling:
        if scaling_factor is not None:
            parameters["scaling_factor"] = scaling_factor
        else:
            # Legacy path: no explicit factor, so only the flag is sent.
            parameters["do_scaling"] = do_scaling
        if shift_factor is not None:
            parameters["shift_factor"] = shift_factor
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width

    headers = {"Content-Type": "tensor/binary", "Accept": "tensor/binary"}
    if output_type == "pil" and processor is None:
        # Without a processor the endpoint returns an encoded image rather than a tensor.
        if image_format == "jpg":
            headers["Accept"] = "image/jpeg"
        elif image_format == "png":
            headers["Accept"] = "image/png"
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"

    tensor_data = tensor.numpy().tobytes()
    return {"data": tensor_data, "params": parameters, "headers": headers}


def remote_decode(
    endpoint: str,
    tensor: "ms.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "ms"] = "pil",
    return_type: Literal["mp4", "pil", "ms"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["binary"] = "binary",
    output_tensor_type: Literal["binary"] = "binary",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, "ms.Tensor"]:
    """
    Hugging Face Hybrid Inference that allows running VAE decode remotely.

    Args:
        endpoint (`str`):
            Endpoint for Remote Decode.
        tensor (`ms.Tensor`):
            Tensor to be decoded.
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Required for `output_type="ms"` with `return_type="pil"` and `partial_postprocess=False`. With
            `output_type="pil"`, any non-`None` value makes the endpoint return a tensor instead of encoded
            image bytes (used for Video models).
        do_scaling (`bool`, default `True`, *optional*):
            **DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.** **still set
            do_scaling=None/do_scaling=False for no scaling until option is removed** When `True` scaling e.g.
            `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed
            with scaling applied.
        scaling_factor (`float`, *optional*):
            Scaling is applied when passed e.g. [`latents /
            self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77).
            - SD v1: 0.18215
            - SD XL: 0.13025
            - Flux: 0.3611
            If `None`, input must be passed with scaling applied.
        shift_factor (`float`, *optional*):
            Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
            - Flux: 0.1159
            If `None`, input must be passed with scaling applied.
        output_type (`"mp4"` or `"pil"` or `"ms"`, default `"pil"`):
            **Endpoint** output type. Subject to change. Report feedback on preferred type.

            `"mp4"`: Supported by video models. Endpoint returns `bytes` of video.
            `"pil"`: Supported by image and video models.
                Image models: Endpoint returns `bytes` of an image in `image_format`.
                Video models: Endpoint returns `ms.Tensor` with partial `postprocessing` applied.
                    Requires `processor` as a flag (any non-`None` value will work).
            `"ms"`: Supported by image and video models. Endpoint returns `ms.Tensor`.
                With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.

            Recommendations:
                `"ms"` with `partial_postprocess=True` is the smallest transfer for full quality. `"ms"` with
                `partial_postprocess=False` is the most compatible with third party code. `"pil"` with
                `image_format="jpg"` is the smallest transfer overall.

        return_type (`"mp4"` or `"pil"` or `"ms"`, default `"pil"`):
            **Function** return type.

            `"mp4"`: Function returns `bytes` of video.
            `"pil"`: Function returns `PIL.Image.Image`.
                With `output_type="pil"` no further processing is applied. With `output_type="ms"` a
                `PIL.Image.Image` is created.
                    `partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is
                    **not** required.
            `"ms"`: Function returns `ms.Tensor`.
                `processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
                denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.

        image_format (`"png"` or `"jpg"`, default `"jpg"`):
            Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.

        partial_postprocess (`bool`, default `False`):
            Used with `output_type="ms"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
            denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.

        input_tensor_type (`"binary"`, default `"binary"`):
            Tensor transfer type. The deprecated value `"base64"` is replaced by `"binary"`.

        output_tensor_type (`"binary"`, default `"binary"`):
            Tensor transfer type. The deprecated value `"base64"` is replaced by `"binary"`.

        height (`int`, *optional*):
            Required for `"packed"` latents.

        width (`int`, *optional*):
            Required for `"packed"` latents.

    Returns:
        output (`Image.Image` or `List[Image.Image]` or `bytes` or `ms.Tensor`).

    Raises:
        `RuntimeError`: If the endpoint answers with an error status. The exception carries the decoded JSON
            body when there is one, otherwise the status code and raw body text.
    """
    if input_tensor_type == "base64":
        deprecate(
            "input_tensor_type='base64'",
            "1.0.0",
            "input_tensor_type='base64' is deprecated. Using `binary`.",
            standard_warn=False,
        )
        input_tensor_type = "binary"
    if output_tensor_type == "base64":
        deprecate(
            "output_tensor_type='base64'",
            "1.0.0",
            "output_tensor_type='base64' is deprecated. Using `binary`.",
            standard_warn=False,
        )
        output_tensor_type = "binary"
    check_inputs_decode(
        endpoint,
        tensor,
        processor,
        do_scaling,
        scaling_factor,
        shift_factor,
        output_type,
        return_type,
        image_format,
        partial_postprocess,
        input_tensor_type,
        output_tensor_type,
        height,
        width,
    )
    kwargs = prepare_decode(
        tensor=tensor,
        processor=processor,
        do_scaling=do_scaling,
        scaling_factor=scaling_factor,
        shift_factor=shift_factor,
        output_type=output_type,
        image_format=image_format,
        partial_postprocess=partial_postprocess,
        height=height,
        width=width,
    )
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        try:
            error = response.json()
        except ValueError:
            # Error bodies are not always JSON (e.g. a gateway error page); a JSON decode failure
            # must not replace the actual HTTP failure.
            error = f"HTTP {response.status_code}: {response.text}"
        raise RuntimeError(error)
    output = postprocess_decode(
        response=response,
        processor=processor,
        output_type=output_type,
        return_type=return_type,
        partial_postprocess=partial_postprocess,
    )
    return output


def check_inputs_encode(
    endpoint: str,
    image: Union["ms.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    """
    Validation hook for `remote_encode` arguments; currently performs no checks.
    """
    return None


def postprocess_encode(
    response: requests.Response,
):
    """
    Rebuild the latent tensor returned by the remote encode endpoint.

    Args:
        response (`requests.Response`):
            Endpoint response whose body holds the raw tensor bytes and whose `shape` (JSON list) and `dtype`
            headers describe them.

    Returns:
        `ms.Tensor`: Tensor with the dtype looked up in `DTYPE_MAP` and the shape given by the header.

    Raises:
        `KeyError`: If the `dtype` header is not a key of `DTYPE_MAP`.
    """
    parameters = response.headers
    shape = json.loads(parameters["shape"])
    dtype = parameters["dtype"]
    mindspore_dtype = DTYPE_MAP[dtype]
    # A bytearray keeps the buffer writable, as the original conversion did.
    raw = bytearray(response.content)
    # The payload must be read with the dtype named in the header: `np.frombuffer` defaults to
    # float64, which would misread the bytes and yield the wrong number of elements.
    if dtype == "bfloat16":
        # NumPy has no bfloat16. Each 16-bit pattern is the upper half of a float32, so widening
        # it into the high bits gives the exact value (assumes native byte order — TODO confirm).
        array = (np.frombuffer(raw, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
    else:
        array = np.frombuffer(raw, dtype=np.dtype(dtype))
    output_tensor = ms.tensor(array, dtype=mindspore_dtype).reshape(shape)
    return output_tensor


def prepare_encode(
    image: Union["ms.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
):
    """
    Build the `requests.post` keyword arguments for a remote encode call.

    Args:
        image (`ms.Tensor` or `PIL.Image.Image`):
            Image to send. A tensor is sent as raw bytes together with its `shape` and `dtype`; any other
            input is saved as PNG and the encoded bytes are sent.
        scaling_factor (`float`, *optional*):
            Forwarded as a request parameter when not `None`.
        shift_factor (`float`, *optional*):
            Forwarded as a request parameter when not `None`.

    Returns:
        `dict`: `{"data": bytes, "params": dict, "headers": dict}`; `headers` is always empty.
    """
    headers = {}
    parameters = {}
    if scaling_factor is not None:
        parameters["scaling_factor"] = scaling_factor
    if shift_factor is not None:
        parameters["shift_factor"] = shift_factor
    if isinstance(image, ms.Tensor):
        data = image.contiguous().numpy().tobytes()
        parameters["shape"] = list(image.shape)
        # Lowercased so the name matches the lowercase dtype names used in this module's `DTYPE_MAP`.
        # NOTE(review): MindSpore dtypes appear to stringify capitalized (e.g. `Float16`) — confirm.
        parameters["dtype"] = str(image.dtype).split(".")[-1].lower()
    else:
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        data = buffer.getvalue()
    return {"data": data, "params": parameters, "headers": headers}


def remote_encode(
    endpoint: str,
    image: Union["ms.Tensor", Image.Image],
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
) -> "ms.Tensor":
    """
    Hugging Face Hybrid Inference that allows running VAE encode remotely.

    Args:
        endpoint (`str`):
            Endpoint for Remote Encode.
        image (`ms.Tensor` or `PIL.Image.Image`):
            Image to be encoded.
        scaling_factor (`float`, *optional*):
            Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
            - SD v1: 0.18215
            - SD XL: 0.13025
            - Flux: 0.3611
            If `None`, the parameter is not sent to the endpoint.
        shift_factor (`float`, *optional*):
            Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
            - Flux: 0.1159
            If `None`, the parameter is not sent to the endpoint.

    Returns:
        output (`ms.Tensor`).

    Raises:
        `RuntimeError`: If the endpoint answers with an error status. The exception carries the decoded JSON
            body when there is one, otherwise the status code and raw body text.
    """
    check_inputs_encode(
        endpoint,
        image,
        scaling_factor,
        shift_factor,
    )
    kwargs = prepare_encode(
        image=image,
        scaling_factor=scaling_factor,
        shift_factor=shift_factor,
    )
    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        try:
            error = response.json()
        except ValueError:
            # Error bodies are not always JSON (e.g. a gateway error page); a JSON decode failure
            # must not replace the actual HTTP failure.
            error = f"HTTP {response.status_code}: {response.text}"
        raise RuntimeError(error)
    output = postprocess_encode(
        response=response,
    )
    return output
